{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7637\n",
      "1626\n",
      "1646\n"
     ]
    }
   ],
   "source": [
    "#定义数据\n",
    "class SurnameDataset(Dataset):\n",
    "    def __init__(self, part):\n",
    "        data = pd.read_csv('./data/surnames/数字化数据.csv')\n",
    "        data = data[data.part == part]\n",
    "\n",
    "        #去掉少于3个字符的名字\n",
    "        def filter_by_len(line):\n",
    "            return len(line.x.split(',')) >= 3\n",
    "\n",
    "        data = data[data.apply(filter_by_len, axis=1)]\n",
    "\n",
    "        self.data = data\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.data.iloc[i, 0], self.data.iloc[i, 1]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "\n",
    "#Build one dataset per split\n",
    "train_dataset, val_dataset, test_dataset = (\n",
    "    SurnameDataset(part=split) for split in ('train', 'val', 'test'))\n",
    "\n",
    "#Report the number of samples in each split\n",
    "for split_dataset in (train_dataset, val_dataset, test_dataset):\n",
    "    print(len(split_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0,  0,  0,  0,  0,  0,  0, 20,  3,  5, 10, 17,  2, 20],\n",
      "        [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 14,  6, 13, 10, 10],\n",
      "        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 13, 27,  3, 18]]) torch.Size([100, 14])\n",
      "tensor([ 3,  3, 12]) torch.Size([100])\n"
     ]
    }
   ],
   "source": [
    "def to_tensor(data):\n",
    "    N = len(data)\n",
    "    #N句话,每句话14个词\n",
    "    xs = np.zeros((N, 14))\n",
    "    #尾字母\n",
    "    ys = np.empty(N)\n",
    "\n",
    "    for i in range(N):\n",
    "        x, y = data[i]\n",
    "\n",
    "        x = x.split(',')\n",
    "        \n",
    "        #取x的最后一个字母作为y\n",
    "        ys[i] = x[-1]\n",
    "\n",
    "        #x去掉最后一个字母\n",
    "        x = x[:-1]\n",
    "        \n",
    "        #反转后补0\n",
    "        x = x[::-1] + ['0'] * 14\n",
    "        \n",
    "        #切割到14位长度\n",
    "        x = x[:14]\n",
    "        \n",
    "        #反转回来\n",
    "        x = x[::-1]\n",
    "        xs[i] = x\n",
    "\n",
    "    return torch.LongTensor(xs), torch.LongTensor(ys)\n",
    "\n",
    "\n",
    "#Data loaders\n",
    "BATCH_SIZE = 100\n",
    "\n",
    "#Training: reshuffle every epoch and drop the incomplete last batch\n",
    "train_dataloader = DataLoader(dataset=train_dataset,\n",
    "                              batch_size=BATCH_SIZE,\n",
    "                              shuffle=True,\n",
    "                              drop_last=True,\n",
    "                              collate_fn=to_tensor)\n",
    "\n",
    "#Evaluation: keep every sample (drop_last=True would silently leave the final\n",
    "#partial batch out of the reported accuracy) and skip shuffling, which has no\n",
    "#effect on accuracy.\n",
    "val_dataloader = DataLoader(dataset=val_dataset,\n",
    "                            batch_size=BATCH_SIZE,\n",
    "                            shuffle=False,\n",
    "                            drop_last=False,\n",
    "                            collate_fn=to_tensor)\n",
    "\n",
    "test_dataloader = DataLoader(dataset=test_dataset,\n",
    "                             batch_size=BATCH_SIZE,\n",
    "                             shuffle=False,\n",
    "                             drop_last=False,\n",
    "                             collate_fn=to_tensor)\n",
    "\n",
    "#Peek at the first training batch; it is kept in `sample` for later cells\n",
    "sample = next(iter(train_dataloader))\n",
    "x, y = sample\n",
    "print(x[:3], x.shape)\n",
    "print(y[:3], y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0758,  0.1360, -0.0190,  ..., -0.1311,  0.0907, -0.0893],\n",
       "        [ 0.1325,  0.0937, -0.0643,  ..., -0.1381, -0.0031, -0.0068],\n",
       "        [ 0.0105,  0.0443, -0.0878,  ..., -0.0548, -0.0068, -0.0708],\n",
       "        ...,\n",
       "        [-0.0507,  0.1098, -0.0063,  ..., -0.0687,  0.0334, -0.1521],\n",
       "        [ 0.0861,  0.1112,  0.1128,  ..., -0.0808,  0.0130, -0.1882],\n",
       "        [ 0.0029,  0.1162, -0.0275,  ...,  0.0031, -0.0506, -0.0095]],\n",
       "       grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义网络模型\n",
    "class SurnameClassifier(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SurnameClassifier, self).__init__()\n",
    "\n",
    "        self.embedding = nn.Embedding(num_embeddings=30,\n",
    "                                      embedding_dim=50,\n",
    "                                      padding_idx=0)\n",
    "\n",
    "        self.rnn = nn.GRU(input_size=50, hidden_size=100, batch_first=True)\n",
    "\n",
    "        self.fc1 = nn.Linear(in_features=100, out_features=100)\n",
    "        self.fc2 = nn.Linear(in_features=100, out_features=30)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        #[b,14] -> [b,14,50]\n",
    "        embed = self.embedding(x)\n",
    "\n",
    "        #[b,14,50] -> [b,14,100],[1,b,100]\n",
    "        out, h = self.rnn(embed)\n",
    "\n",
    "        #[b,100] -> [b,30]\n",
    "        out = F.relu(self.fc1(F.dropout(h.squeeze(), 0.2)))\n",
    "        out = self.fc2(F.dropout(out, 0.2))\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "#Seed torch's global RNG so weight initialisation (and everything drawn from\n",
    "#that RNG afterwards) is repeatable when the notebook is re-run top to bottom.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "model = SurnameClassifier()\n",
    "\n",
    "#Smoke test: forward the sample batch through the freshly created model\n",
    "model(sample[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.018125"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test(dataloader):\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for i, data in enumerate(dataloader):\n",
    "        x, y = data\n",
    "\n",
    "        y_pred = model(x)\n",
    "        y_pred = y_pred.argmax(dim=1)\n",
    "\n",
    "        correct += (y_pred == y).sum().item()\n",
    "        total += len(y)\n",
    "\n",
    "    return correct / total\n",
    "\n",
    "\n",
    "#Validation accuracy at this point in the notebook, before the training cell below\n",
    "test(val_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2.2973945140838623 0.300625\n",
      "1 1.9206297397613525 0.401875\n",
      "2 1.6891889572143555 0.4375\n",
      "3 1.960289478302002 0.47375\n",
      "4 1.8121192455291748 0.484375\n",
      "5 1.6147767305374146 0.5\n",
      "6 1.5749751329421997 0.5175\n",
      "7 1.4665619134902954 0.535625\n",
      "8 1.4422262907028198 0.540625\n",
      "9 1.2665151357650757 0.554375\n",
      "10 1.5351853370666504 0.56375\n",
      "11 1.1230456829071045 0.5675\n",
      "12 1.4524245262145996 0.578125\n",
      "13 1.3921475410461426 0.566875\n",
      "14 0.9999076128005981 0.57375\n",
      "15 1.428695559501648 0.59125\n",
      "16 1.2512894868850708 0.58125\n",
      "17 0.9031326174736023 0.57625\n",
      "18 0.9328461289405823 0.58\n",
      "19 0.9816864728927612 0.58375\n"
     ]
    }
   ],
   "source": [
    "loss_func = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "for epoch in range(20):\n",
    "    #Re-enter training mode every epoch: test() puts the model in eval mode,\n",
    "    #so a single model.train() before the loop would only cover epoch 0.\n",
    "    model.train()\n",
    "\n",
    "    for x, y in train_dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "\n",
    "        loss = loss_func(y_pred, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    #Report the loss of the epoch's last batch and the validation accuracy\n",
    "    accuracy = test(val_dataloader)\n",
    "    print(epoch, loss.item(), accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.57"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Accuracy on the test split, evaluated after the training cell above\n",
    "test(test_dataloader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "138px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": "5",
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
